# -*- encoding: utf-8 -*-
'''
@File    :   dataSynthesis.py
@Time    :   2023/09/02 09:56:52
@Author  :   chenyang
@Version :   1.0
@Desc    :   使用模版+元数据合成训练样本
'''

import os, sys
sys.path.insert(0, os.getcwd())
import re
import random
import json
import argparse
from tqdm import tqdm
from typing import List, Dict, Any, Tuple

from utils import split_Mix_word

random.seed(10010)

class DataSynthesis:
    """Synthesize intent-classification / slot-filling training data.

    For each configured intent, corpus lines are inserted into randomly
    chosen single-slot templates to build (intent, text, slot-labels)
    triples; the pool is split 8:1:1 into train/dev/test and written out
    together with the intent and slot label vocabularies.
    """

    def __init__(self, templatePath:str, datasetDir:str) -> None:
        # Path of the template file used for synthesis.
        self.templatePath = templatePath
        # Output directory for the synthesized dataset.
        self.datasetDir = datasetDir
        # Directory holding the source corpora used for synthesis.
        self.baseCorpusDir = "dataGenerate/corpus"
        # Scenarios to synthesize; each corresponds to an intent key in template.json.
        self.intentList = ["TranslationEnZh","TranslationZhEn","IdiomExplanation",
                           "CreateSentence","Antonym", "Synonym","GroupWord"]

    def convert2slotLabel(self, sentence:str, slotName:str) -> str:
        """Convert a sentence (or slot value) into a space-joined slot-label string.

        If `sentence` contains "#" it is treated as template text whose "#"
        marks the slot position (kept verbatim so it can be substituted later)
        and every other token is labelled `slotName`. Otherwise the whole
        sentence is one slot value and receives BIO tags B-/I-`slotName`.
        """
        words = split_Mix_word(sentence)
        if "#" in sentence:     # "#" only marks the slot position, it has no meaning of its own
            slotLabel = [slotName if word != "#" else word for word in words]
        else:
            slotLabel = [f"B-{slotName}" if i == 0 else f"I-{slotName}" for i in range(len(words))]
        return " ".join(slotLabel)

    def __generate_label(self, corpusList, intent, intentTemplates, sampleN, corpusName, slotDict):
        """Generate (intent, text, slot-labels) samples for one corpus file.

        Each corpus line is substituted into a randomly chosen single-slot
        template, e.g. template="翻译:{TransEnZhSentence}". At most `sampleN`
        samples are produced per corpus; `slotDict` is updated in place with
        every slot name encountered (used later for slot_label.txt).
        """
        samples = []
        # Hoisted out of the loop: extracts the (single) slot name from a template.
        slotPattern = re.compile(r".*\{(.*)\}.*")
        for item in tqdm(corpusList, desc=intent):  # expand source lines through templates
            item = item.strip("-").strip()
            template = random.choice(intentTemplates)
            slotName = slotPattern.findall(template)[0]
            # BIO labels for the slot value itself.
            slotLabel0 = self.convert2slotLabel(item, slotName)

            # "O" labels for the template text, with "#" as the slot placeholder.
            template1 = template.replace("{%s}"%slotName, "#")
            slotLabel1 = self.convert2slotLabel(template1, slotName="O")
            slotLabel = slotLabel1.replace("#", slotLabel0)

            sampleText = template.replace("{%s}"%slotName, item)
            assert len(slotLabel.split(" "))==len(split_Mix_word(sampleText)), "分词和slot id数量不对应"
            samples.append((intent, sampleText, slotLabel))

            slotDict.setdefault(slotName, None)  # record the slot name for slot_label.txt
            # Cap the samples per corpus so one huge corpus (50k+ lines) cannot
            # dominate the dataset. ">=" (not ">") so the cap is exactly sampleN,
            # not sampleN + 1 as the original off-by-one allowed.
            if len(samples) >= sampleN:
                print(f"number of samples={len(samples)}, generated by {corpusName}")
                break
        return samples

    def __synthesis_training_data(self, templateDict:Dict, slotDict:Dict, intent:str, sampleN:int=10) -> List:
        """Synthesize intent-detection and slot-filling samples for one intent."""
        datalist = []
        synthesisDict = templateDict[intent]
        intentTemplates = synthesisDict["template"]
        intentCorpusNames = synthesisDict["corpus"]

        for corpusName in intentCorpusNames:
            corpusPath = os.path.join(self.baseCorpusDir, corpusName)
            assert os.path.exists(corpusPath), f"Not found {corpusPath}, please check."
            # Explicit utf-8: the corpora contain Chinese text and the platform
            # default encoding is not guaranteed to handle it.
            with open(corpusPath, encoding="utf-8") as f:
                corpusList = f.readlines()
            samples = self.__generate_label(corpusList, intent, intentTemplates, sampleN, corpusName, slotDict)
            # Merge the samples synthesized from each source corpus.
            datalist += samples
        return datalist

    def __split_dataset(self, datalist:List, testRatio:float=0.1) -> Tuple[List, List, List]:
        """Shuffle in place and split train:dev:test = 8:1:1 (for testRatio=0.1)."""
        random.shuffle(datalist)

        n = int(testRatio * len(datalist))
        testList = datalist[:n]
        devList  = datalist[n:2*n]
        trainList = datalist[2*n:]
        return trainList, devList, testList

    def __call__(self, sampleN:int, *args: Any, **kwds: Any) -> None:
        """Synthesize, split and persist the full dataset plus label vocabularies."""
        assert os.path.exists(self.templatePath), f"Not found {self.templatePath}."
        with open(self.templatePath, encoding="utf-8") as f:
            templateDict = json.load(f)

        # Synthesize the data for every task scenario (intent).
        trainset, devset, testset = [],[],[]
        slotDict = {}
        for intent in self.intentList:
            assert templateDict.get(intent), f"Not found [{intent}] in {self.templatePath}"
            datalist = self.__synthesis_training_data(templateDict, slotDict, intent, sampleN)
            trainList, devList, testList = self.__split_dataset(datalist)
            print(f"intent={intent},totalN={len(datalist)}, #trainList={len(trainList)},#devList={len(devList)}, #testlist={len(testList)}")

            trainset += trainList; devset += devList; testset += testList
        print(f"#trainset={len(trainset)}, #devset={len(devset)}, #testset={len(testset)}")

        # Save all splits: one line per sample in label / seq.in / seq.out.
        random.shuffle(trainset); random.shuffle(devset); random.shuffle(testset)
        # Explicit mapping instead of eval(f"{mode}set"): eval is unsafe and
        # opaque; a dict preserves the same train/dev/test order.
        splits = {"train": trainset, "dev": devset, "test": testset}
        for mode, dataset in splits.items():
            outdir = os.path.join(self.datasetDir, mode)
            os.makedirs(outdir, exist_ok=True)
            for i, filename in enumerate(["label", "seq.in", "seq.out"]):
                with open(os.path.join(outdir, filename), "w", encoding="utf-8") as f:
                    f.write("\n".join([item[i] for item in dataset]))
        # Save the intent and slot label vocabularies.
        intentList = ["UNK"] + self.intentList
        slotList   = ["PAD","UNK","O"]
        for slot in slotDict.keys():
            slotList += [f"B-{slot}", f"I-{slot}"]
        with open(os.path.join(self.datasetDir,"intent_label.txt"), "w", encoding="utf-8") as f:
            f.write("\n".join(intentList))
        with open(os.path.join(self.datasetDir,"slot_label.txt"), "w", encoding="utf-8") as f:
            f.write("\n".join(slotList))
        print(f"output dir = {self.datasetDir}")

if __name__=="__main__":
    """Usage
        python3 dataGenerate/dataSynthesis.py
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--templatePath", type=str, default="data/templates/template.json",
                        help="预定义模板路径")
    # Fixed copy-pasted help text: this flag is the dataset output directory,
    # not the template path.
    parser.add_argument("--datasetDir", type=str, default="data/generalQA",
                        help="合成数据集的输出目录")
    parser.add_argument("--sampleN", type=int, default=10000,
                        help="每个源数据采样的样本数")
    args = parser.parse_args()
    # Build the synthesizer and run the full pipeline with the sampling cap.
    datasynObj = DataSynthesis(args.templatePath, args.datasetDir)
    datasynObj(args.sampleN)